/*
 * Copyright (c) Hisilicon Technologies Co., Ltd. 2022-2022. All rights reserved.
 * Description: 设置轻量化配置
 */
#include "quantize_optimizer.h"

#include <utility>
#include <string>
#include <map>

#include "graph/op/all_ops.h"
#include "infra/base/securestl.h"
#include "infra/base/assertion.h"
#include "common/math/math_util.h"
#include "omg/params.h"
#include "omg/quantize_optimizer/quantize_saver.h"
#include "omg/quantize_optimizer/quantize_cfg_parser.h"
#include "omg/quantize_optimizer/quantize_util.h"
#include "omg/quantize_optimizer/quantize_math_util.h"
#include "omg/quantize_optimizer/quantize_types.h"
#include "framework/common/op_types.h"
#include "framework/infra/log/log.h"
#include "framework/graph/core/cgraph/graph_finder.h"
#include "framework/graph/core/cgraph/graph_list_walker.h"
#include "framework/graph/core/cgraph/graph_modifier.h"
#include "framework/graph/core/cgraph/graph_sorter.h"
#include "framework/graph/core/cgraph/graph_spec.h"
#include "framework/graph/core/node/node_spec.h"
#include "framework/graph/core/node/node_walker.h"
#include "framework/graph/op/internal_nn_defs.h"
#include "framework/graph/utils/attr_utils.h"
#include "framework/graph/utils/graph_utils.h"
#include "graph/buffer.h"

using namespace std;
using namespace ge;

namespace hiai {
namespace {
bool CheckQuantizeInfosV1(const std::map<std::string, QuantizeConfig>& quantizeConfigs, ge::ComputeGraphPtr graph)
{
    map<string, QuantizeConfig>::const_iterator it = quantizeConfigs.cbegin();
    for (; it != quantizeConfigs.cend(); it++) {
        auto nodeVisitor = [&it](Node& node) {
            const string& opName = it->first;
            if (node.ROLE(NodeSpec).Name() != it->first) {
                return hiai::SUCCESS;
            }
            if (QuantizeUtil::CheckOpQuantizeInfo(node, it->second) != hiai::SUCCESS) {
                FMK_LOGE("Check quantize infos fail, op name:%s", opName.c_str());
                return hiai::FAILURE;
            }
            return hiai::SUCCESS;
        };
        HIAI_EXPECT_EXEC_R(GraphUtils::WalkAllSubGraphNodes(*graph, nodeVisitor), false);
    }
    return true;
}

bool CheckScaleOffsetSize(const std::vector<QuantizeParams>& quantParams)
{
    bool comResult = true;
    for (auto it = quantParams.cbegin(); it != quantParams.cend(); it++) {
        for (auto it2 = it->operatorParams.cbegin(); it2 != it->operatorParams.cend(); it2++) {
            if (it2->operType == OperatorType::REQUANTIZE) {
                continue;  // Requantize的offset数量可能与scale数量不一致，取决于下一个算子
            }
            if (it2->scale.size() != it2->offset.size()) {
                FMK_LOGE("Scale size and offset size is not equal.operType:%d", it2->operType);
                comResult = false;
                break;
            }
        }
    }
    return comResult;
}

bool CheckQuantizeInfosV2(const std::map<std::string, QuantizeV2Config>& quantizeV2Configs, ge::ComputeGraph& graph)
{
    // Validates every V2 quantize config entry against the graph:
    //  - input/output scale and offset counts must match;
    //  - the configured node must exist, searched in three stages (root graph,
    //    all subgraphs, then the "original_op_names" attr of fused nodes);
    //  - each quantized input must have a connected peer node, except the
    //    bias entry (single int32 param), which is computed and may be absent;
    //  - at most one output may carry quant params, and only at index 0.
    bool checkRet = true;
    for (auto it = quantizeV2Configs.cbegin(); it != quantizeV2Configs.cend(); it++) {
        if (!CheckScaleOffsetSize(it->second.inputQuantParams) || !CheckScaleOffsetSize(it->second.outputQuantParams)) {
            FMK_LOGE("Op:%s Quantize config invalid!", it->first.c_str());
            checkRet = false;
            break;
        }
        std::string nodeName = it->first;
        // Stage 1: look for the node in the root graph.
        ge::Node* node = graph.ROLE(GraphFinder).FindNode(nodeName);
        if (node == nullptr) {
            // Stage 2: look for the node inside every subgraph.
            GraphUtils::WalkAllSubGraphs(graph, [&nodeName, &node](ge::ComputeGraphPtr& subGraph) {
                if (node == nullptr) {
                    node = subGraph->ROLE(GraphFinder).FindNode(nodeName);
                }
                return hiai::SUCCESS;
            });
        }
        if (node == nullptr) {
            // Stage 3: the node may have been fused away; match the configured
            // name against each node's recorded original op names.
            (void)graph.ROLE(GraphListWalker).WalkAllNodes([&nodeName, &node](ge::Node &tNode) {
                std::vector<std::string> originalOpNames;
                auto& opDest = tNode.ROLE(NodeSpec).OpDesc();
                (void)ge::AttrUtils::GetListStr(opDest, "original_op_names", originalOpNames);
                for (const std::string &originalName : originalOpNames) {
                    if (nodeName == originalName) {
                        node = &tNode;
                        return hiai::SUCCESS;
                    }
                }
                return hiai::SUCCESS;
            });
        }
        HIAI_EXPECT_NOT_NULL_R(node, false);
        for (uint32_t i = 0; i < it->second.inputQuantParams.size(); i++) {
            uint32_t inputIndex = it->second.inputQuantParams[i].index;
            ge::Node* peerInNode = node->ROLE(NodeWalker).InDataNode(inputIndex);
            // Bias quant params are derived from the input and filter params,
            // so a bias input (single int32 param) may legitimately be absent.
            if (it->second.inputQuantParams[i].operatorParams.size() == 1 &&
                it->second.inputQuantParams[i].operatorParams[0].dataType == ge::DT_INT32) {
                continue;
            }
            HIAI_EXPECT_NOT_NULL_R(peerInNode, false);
        }
        HIAI_EXPECT_TRUE_R(it->second.outputQuantParams.size() <= 1, false);
        if (it->second.outputQuantParams.size() == 1) {
            HIAI_EXPECT_TRUE_R(it->second.outputQuantParams[0].index == 0, false);
        }
    }
    return checkRet;
}

bool CheckQuantizeInfos(const ModelLightWeightParams& params, ge::ComputeGraphPtr graph)
{
    // Dispatch on config version: an empty version string selects the legacy
    // V1 config format, anything else selects V2.
    const bool isLegacyConfig = (params.version == "");
    return isLegacyConfig ? CheckQuantizeInfosV1(params.quantizeConfigs, graph)
                          : CheckQuantizeInfosV2(params.quantizeV2Configs, *graph);
}

Status GetQuantizeOps(ge::ComputeGraphPtr graph, vector<Node*>& quantOps)
{
    // Collects every node (root graph and subgraphs) whose op type supports
    // quantization into quantOps.
    auto collector = [&quantOps](Node& node) {
        if (QuantizeUtil::IsSupportQuantOpType(node.ROLE(NodeSpec).Type())) {
            quantOps.push_back(&node);
        }
        return hiai::SUCCESS;
    };
    HIAI_EXPECT_EXEC(GraphUtils::WalkAllSubGraphNodes(*graph, collector));
    return hiai::SUCCESS;
}

string GetQuantizeOpName(Node* node, bool useWeightName)
{
    // Returns the key used to match this node against the quantize configs:
    // the node's own name, or (when useWeightName) the name of its filter
    // const input. Falls back to the node name on any lookup failure.
    string compOpName = node->ROLE(NodeSpec).Name();
    if (useWeightName) {
        vector<string> inputNames = OpDescUtils::GetConstInputNames(*node);
        if (inputNames.empty()) {
            FMK_LOGE("Get input name from op:%s failed.", node->ROLE(NodeSpec).Name().c_str());
            return compOpName;
        }
        uint32_t index = QuantizeUtil::GetFilterIndex(node);
        // Bounds-check the filter index: the original indexed inputNames
        // unchecked, risking an out-of-range read.
        if (index >= inputNames.size()) {
            FMK_LOGE("Filter index %u out of range for op:%s.", index, node->ROLE(NodeSpec).Name().c_str());
            return compOpName;
        }
        compOpName = inputNames[index];
    }
    return compOpName;
}

Status SetQuantizeInfosV1(ModelLightWeightParams& params, int64_t weightDataAddr, ge::ComputeGraphPtr graph)
{
    // Applies every V1 quantize config to its matching quantizable node.
    // Consumed entries are erased; any leftover entry means the config names
    // an op that is not in the graph, which is treated as a failure.
    vector<Node*> supportQuantOps;
    HIAI_EXPECT_EXEC(GetQuantizeOps(graph, supportQuantOps));

    bool hasSeted = false;
    for (Node* node : supportQuantOps) {
        string compOpName = GetQuantizeOpName(node, params.useWeightName);
        // Single lookup: the original did find() followed by operator[],
        // searching the map twice for the same key.
        auto cfgIter = params.quantizeConfigs.find(compOpName);
        if (cfgIter == params.quantizeConfigs.end()) {
            continue;
        }
        QuantizeConfig& quantizeConfig = cfgIter->second;
        if (QuantizeSaver::SaveOpQuantV1Params(quantizeConfig, node, weightDataAddr) != hiai::SUCCESS) {
            FMK_LOGE("Node: %s save quantize parameters failed.", compOpName.c_str());
            return hiai::FAILED;
        }

        bool isOneSideQuantize =
            (quantizeConfig.inputDataType == DT_FLOAT16 || quantizeConfig.inputDataType == DT_FLOAT) &&
            (quantizeConfig.weightDataType == DT_INT8);
        if (isOneSideQuantize && !hasSeted) {
            // Set the one-side-quantization flag on the graph once.
            (void)QuantizeUtil::SetOneSideQuantize(*graph, true);
            hasSeted = true;
        }
        params.quantizeConfigs.erase(cfgIter);
    }
    if (!params.quantizeConfigs.empty()) {
        FMK_LOGE("Node: %s has quantize params, but not in the graph.", params.quantizeConfigs.begin()->first.c_str());
        return hiai::FAILED;
    }

    return hiai::SUCCESS;
}

bool CacheAllQuantizeNodes(const ge::ComputeGraph& graph,
    const std::map<std::string, QuantizeV2Config>& quantizeV2Configs, std::map<std::string, ge::Node*>& targetNodes)
{
    // Resolves every configured node name to a live graph node and caches the
    // mapping in targetNodes. Lookup happens in three stages: root graph,
    // all subgraphs, then the "original_op_names" attr of fused nodes.
    // Fails (returns false) if any configured name cannot be resolved.
    std::map<std::string, QuantizeV2Config>::const_iterator it = quantizeV2Configs.cbegin();
    for (; it != quantizeV2Configs.cend(); it++) {
        std::string nodeName = it->first;
        // Stage 1: root graph.
        ge::Node* node = graph.ROLE(GraphFinder).FindNode(nodeName);
        if (node == nullptr) {
            // Stage 2: every subgraph.
            GraphUtils::WalkAllSubGraphs(graph, [&nodeName, &node](ge::ComputeGraphPtr& subGraph) {
                if (node == nullptr) {
                    node = subGraph->ROLE(GraphFinder).FindNode(nodeName);
                }
                return hiai::SUCCESS;
            });
        }
        if (node == nullptr) {
            // Stage 3: the node may have been fused; match the configured name
            // against each remaining node's recorded original op names.
            (void)graph.ROLE(GraphListWalker).WalkAllNodes([&nodeName, &node](ge::Node &tNode) {
                std::vector<std::string> originalOpNames;
                auto& opDest = tNode.ROLE(NodeSpec).OpDesc();
                (void)ge::AttrUtils::GetListStr(opDest, "original_op_names", originalOpNames);
                for (const std::string &originalName : originalOpNames) {
                    if (nodeName == originalName) {
                        node = &tNode;
                        return hiai::SUCCESS;
                    }
                }
                return hiai::SUCCESS;
            });
        }
        if (node == nullptr) {
            FMK_LOGE("Node:%s has quant params, but not in the graph.", nodeName.c_str());
            return false;
        }
        targetNodes[nodeName] = node;
    }

    return true;
}

Status SetQuantizeInfosV2(ModelLightWeightParams& params, int64_t weightDataAddr, ge::ComputeGraphPtr& graph)
{
    // Resolves all configured nodes up front, then saves each node's V2
    // quantize params and consumes the corresponding config entry.
    std::map<std::string, ge::Node*> targetNodes;
    HIAI_EXPECT_TRUE(CacheAllQuantizeNodes(*graph, params.quantizeV2Configs, targetNodes));

    bool oneSideFlagSet = false;
    for (auto& entry : targetNodes) {
        const std::string& opName = entry.first;
        Node* quantNode = entry.second;
        QuantizeV2Config& config = params.quantizeV2Configs[opName];
        if (QuantizeSaver::SaveOpQuantV2Params(config, quantNode, weightDataAddr) != hiai::SUCCESS) {
            FMK_LOGE("Node: %s save quantize parameters failed.", opName.c_str());
            return hiai::FAILED;
        }

        if (config.isOneSideQuantize && !oneSideFlagSet) {
            // Set the one-side-quantization flag on the graph once.
            (void)QuantizeUtil::SetOneSideQuantize(*graph, true);
            oneSideFlagSet = true;
        }
        params.quantizeV2Configs.erase(opName);
    }

    return hiai::SUCCESS;
}

Status SetQuantizeInfos(ModelLightWeightParams& params, ge::ComputeGraphPtr graph)
{
    // Reads the merged-weight base address/size from graph attributes when
    // present, then dispatches to the V1 or V2 saver by config version.
    bool hasMergedWeight = graph->HasAttr(SRC_MERGED_WEIGHT_ADDR) && graph->HasAttr(SRC_MERGED_WEIGHT_SIZE);
    int64_t weightDataSize = 0;
    int64_t weightDataAddr = 0;
    if (hasMergedWeight) {
        (void)ge::AttrUtils::GetInt(graph, SRC_MERGED_WEIGHT_ADDR, weightDataAddr);
        (void)ge::AttrUtils::GetInt(graph, SRC_MERGED_WEIGHT_SIZE, weightDataSize);
    }
    // %lld for the int64_t argument: the original used %d, which is undefined
    // behavior for a 64-bit vararg (file convention elsewhere uses %lld).
    FMK_LOGI("Use MergedWeight:%d, weightDataSize:%lld", hasMergedWeight, weightDataSize);
    if (params.version == "") {
        return SetQuantizeInfosV1(params, weightDataAddr, graph);
    } else {
        return SetQuantizeInfosV2(params, weightDataAddr, graph);
    }
}

Status LoadQuantizeConfigs(ModelLightWeightParams& params, ge::ComputeGraphPtr graph)
{
    // Validates the parsed quantize configs and applies them to the graph.
    // An empty config set is reported as a failure so callers can surface it.
    if (params.quantizeConfigs.size() == 0 && params.quantizeV2Configs.size() == 0) {
        FMK_LOGW("Quantize config file has no quant infos, please check.");
        return hiai::FAILED;
    }

    if (!CheckQuantizeInfos(params, graph)) {
        FMK_LOGE("Check quantize infos failed.");
        return hiai::FAILED;
    }

    if (SetQuantizeInfos(params, graph) != hiai::SUCCESS) {
        FMK_LOGE("Set quantize info failed.");
        return hiai::FAILED;
    }

    // Qualify explicitly: every other return in this file uses hiai::, and a
    // bare SUCCESS is ambiguous to readers with `using namespace ge` in effect.
    return hiai::SUCCESS;
}

void RemoveFusionInfos(ge::ComputeGraph& graph)
{
    // Strips the trans-scale attributes that fusion passes attached to
    // quantizable ops; non-quantizable ops are left untouched.
    (void)graph.ROLE(GraphListWalker).WalkAllNodes([](Node& node) {
        OpDesc& desc = node.ROLE(NodeSpec).OpDesc();
        if (QuantizeUtil::IsSupportQuantOpType(desc.GetType())) {
            if (QuantizeUtil::HasTransScale(desc)) {
                QuantizeUtil::DelTransScale(desc);
            }
            if (QuantizeUtil::HasPowerTransScale(desc)) {
                QuantizeUtil::DelPowerTransScale(desc);
            }
        }
        return hiai::SUCCESS;
    });
}
} // namespace

Status QuantizeOptimizer::Optimize(const char* file, ge::ComputeGraphPtr graph)
{
    // Applies the quantize config file to the graph. A null or empty path only
    // clears fusion-added quantize attributes and succeeds.
    HIAI_EXPECT_NOT_NULL_R(graph, hiai::PARAM_INVALID);
    if (file == nullptr || *file == '\0') {
        // Remove quantize-related attributes added by fusion.
        RemoveFusionInfos(*graph);
        return hiai::SUCCESS;
    }

    // Note: the original declared an unused local `string filePath = file;`
    // which has been removed.
    ModelLightWeightParams params;
    if (QuantizeCfgParser::ParseConfigFromFile(file, params) != hiai::SUCCESS) {
        FMK_LOGE("Parse quantize config failed.");
        return hiai::FAILED;
    }
    if (LoadQuantizeConfigs(params, graph) != hiai::SUCCESS) {
        FMK_LOGE("Load quantize config fail.");
        return hiai::FAILED;
    }
    RemoveFusionInfos(*graph);

    // Re-sort the nodes after the graph has been modified.
    return graph->ROLE(GraphSorter).SortNodesDFS();
}
Status QuantizeOptimizer::Optimize(ge::BaseBuffer& quantizeBuffer, ge::ComputeGraphPtr graph)
{
    // Applies a quantize config held in memory to the graph.
    // Guard against a null graph, consistent with the file-based overload:
    // LoadQuantizeConfigs dereferences the graph unconditionally.
    HIAI_EXPECT_NOT_NULL_R(graph, hiai::PARAM_INVALID);

    ModelLightWeightParams params;
    if (QuantizeCfgParser::ParseConfigFromBuffer(quantizeBuffer.data(), quantizeBuffer.size(), params) !=
        hiai::SUCCESS) {
        FMK_LOGE("Parse quantize config failed.");
        return hiai::FAILED;
    }

    return LoadQuantizeConfigs(params, graph);
}

namespace {
// Quant-type attribute value meaning this input/weight is not quantized.
const int8_t NOT_SUPPORT_QUANTIZE = 0;
// Quant-type attribute value meaning INT8 quantization.
const int8_t QUANTIZE_INT8 = 1;
// Positions of the attribute names inside the *_QUANT_ATTRS vectors below;
// every vector must list its attributes in exactly this order.
enum AttrIndex {
    X_QUANT_TYPE_INDEX = 0,
    W_QUANT_TYPE_INDEX,
    X_QUANT_SCALE_INDEX,
    X_QUANT_OFFSET_INDEX,
    W_QUANT_SCALES_INDEX
};
// Attribute name set for quantized convolution ops.
const vector<string> CONV_QUANT_ATTRS { hiai::op::QuantizedConvolution::x_quant_type,
    hiai::op::QuantizedConvolution::filter_quant_type, hiai::op::QuantizedConvolution::x_quant_scale,
    hiai::op::QuantizedConvolution::x_quant_offset, hiai::op::QuantizedConvolution::filter_quant_scales };

// Attribute name set for quantized fully-connection (and GemmD) ops.
const vector<string> FC_QUANT_ATTRS { hiai::op::QuantizedFullyConnection::x_quant_type,
    hiai::op::QuantizedFullyConnection::w_quant_type, hiai::op::QuantizedFullyConnection::x_quant_scale,
    hiai::op::QuantizedFullyConnection::x_quant_offset, hiai::op::QuantizedFullyConnection::w_quant_scales };

// Attribute name set for quantized matmul ops.
const vector<string> MATMUL_QUANT_ATTRS { hiai::op::QuantizedMatMul::x1_quant_type,
    hiai::op::QuantizedMatMul::x2_quant_type, hiai::op::QuantizedMatMul::x1_quant_scale,
    hiai::op::QuantizedMatMul::x1_quant_offset, hiai::op::QuantizedMatMul::x2_quant_scales };

bool GetQuantDataType(ge::OpDesc& opDesc, const string& dataTypeAttr, ge::DataType& dataType)
{
    // Reads the quant-type attribute named by dataTypeAttr and maps it to a
    // data type. Only QUANTIZE_INT8 is supported; on success dataType = DT_INT8.
    int64_t quantType = 0;
    if (!ge::AttrUtils::GetInt(opDesc, dataTypeAttr, quantType)) {
        // Log the actual attribute name: this helper serves both the x and the
        // weight quant type, so hard-coding "xQuantType" was misleading.
        FMK_LOGW("Op %s could not get %s from opDesc.", opDesc.GetName().c_str(), dataTypeAttr.c_str());
        return false;
    }
    if (quantType != QUANTIZE_INT8) {
        FMK_LOGE("Op %s quantize type %s:%lld is not supported.", opDesc.GetName().c_str(), dataTypeAttr.c_str(),
            quantType);
        return false;
    }

    dataType = DT_INT8;
    return true;
}

bool GetQuantizeInfoFromIR(ge::OpDesc& opDesc, const vector<string>& quantAttrs, QuantizeConfig& quantizeConfig)
{
    // Extracts the IR quantize attributes (x/weight quant types, x scale and
    // offset, per-channel weight scales) named in quantAttrs from opDesc and
    // fills quantizeConfig. Returns false if any attribute is missing or
    // invalid; quantizeConfig is only written on the all-valid path.
    if (quantAttrs.size() <= W_QUANT_SCALES_INDEX) {
        FMK_LOGW("quantAttrs size less than 5");
        return false;
    }

    ge::DataType inputDType;
    if (!GetQuantDataType(opDesc, quantAttrs[X_QUANT_TYPE_INDEX], inputDType)) {
        return false;
    }
    inputDType = (inputDType == DT_INT8) ? DT_UINT8 : inputDType; // IR-defined INT8 quantization only supports uint8 input
    ge::DataType weightDType;
    if (!GetQuantDataType(opDesc, quantAttrs[W_QUANT_TYPE_INDEX], weightDType)) {
        return false;
    }

    float xScaleData = 0.0;
    if (!ge::AttrUtils::GetFloat(opDesc, quantAttrs[X_QUANT_SCALE_INDEX], xScaleData)) {
        FMK_LOGE("Op %s could not get xScaleData from opDesc.", opDesc.GetName().c_str());
        return false;
    }
    if (xScaleData <= 0.0) {
        FMK_LOGE("Op %s xScaleData %f is less or equal with 0.", opDesc.GetName().c_str(), xScaleData);
        return false;
    }
    int32_t xOffsetData = 0;
    if (!ge::AttrUtils::GetInt(opDesc, quantAttrs[X_QUANT_OFFSET_INDEX], xOffsetData)) {
        FMK_LOGE("Op %s could not get xOffsetData from opDesc.", opDesc.GetName().c_str());
        return false;
    }

    vector<float> wQuantScales;
    if (!ge::AttrUtils::GetListFloat(opDesc, quantAttrs[W_QUANT_SCALES_INDEX], wQuantScales)) {
        FMK_LOGE("Op %s could not get wQuantScales from opDesc.", opDesc.GetName().c_str());
        return false;
    }
    if (wQuantScales.size() == 0) {
        FMK_LOGW("Op %s wQuantScales vector size is 0.", opDesc.GetName().c_str());
        return false;
    }
    // Construct the weight offsets: one zero per weight scale.
    vector<float> wOffsetData(wQuantScales.size(), 0.0);

    quantizeConfig.inputDataType = inputDType;
    quantizeConfig.weightDataType = weightDType;
    quantizeConfig.inputScale.push_back(xScaleData);
    quantizeConfig.inputOffset.push_back(xOffsetData);
    // swap() moves the vectors into the config without copying.
    quantizeConfig.weightScale.swap(wQuantScales);
    quantizeConfig.weightOffset.swap(wOffsetData);
    return true;
}

bool IsIRSupportQuantize(const ge::Node& node)
{
    // Checks whether the node carries usable IR quantize attributes and
    // correctly-typed weights: quant-type attrs present and not both
    // NOT_SUPPORT_QUANTIZE, filter tensor int8, optional bias tensor int32.
    auto& opDesc = node.ROLE(NodeSpec).OpDesc();
    // Select the attribute-name set by op type through a pointer; the original
    // copied a vector<string> (five string copies) on every call.
    const vector<string>* quantAttrs = &CONV_QUANT_ATTRS;
    if (opDesc.GetType() == hiai::op::FullyConnection::TYPE ||
        opDesc.GetType() == hiai::op::GemmD::TYPE) {
        quantAttrs = &FC_QUANT_ATTRS;
    } else if (opDesc.GetType() == hiai::op::MatMul::TYPE) {
        quantAttrs = &MATMUL_QUANT_ATTRS;
    }
    int64_t xQuantType = 0;
    if (!opDesc.HasAttr((*quantAttrs)[X_QUANT_TYPE_INDEX])) {
        return false;
    }
    (void)ge::AttrUtils::GetInt(opDesc, (*quantAttrs)[X_QUANT_TYPE_INDEX], xQuantType);
    int64_t wQuantType = 0;
    if (!ge::AttrUtils::GetInt(opDesc, (*quantAttrs)[W_QUANT_TYPE_INDEX], wQuantType)) {
        return false;
    }
    if (xQuantType == NOT_SUPPORT_QUANTIZE && wQuantType == NOT_SUPPORT_QUANTIZE) {
        return false;
    }
    // Data type validation of the weight tensors.
    vector<ge::TensorPtr> weights = ge::OpDescUtils::MutableWeights(node);
    size_t weightsSize = weights.size();
    if (weightsSize < 1) {
        FMK_LOGE("Op %s Weight size is less then 1.", opDesc.GetName().c_str());
        return false;
    }
    ge::TensorPtr filter = QuantizeUtil::GetFilterTensor(&node);
    HIAI_EXPECT_NOT_NULL_R(filter, false);
    if (filter->GetTensorDesc().GetDataType() != ge::DT_INT8) {
        FMK_LOGW("Op %s filter type is not int8.", opDesc.GetName().c_str());
        return false;
    }
    if (weightsSize > 1) {
        ge::TensorPtr bias = weights[1];
        HIAI_EXPECT_NOT_NULL_R(bias, false);
        if (bias->GetTensorDesc().GetDataType() != ge::DT_INT32) {
            FMK_LOGW("Op %s bias type is not int32.", opDesc.GetName().c_str());
            return false;
        }
    }
    return true;
}

Status CheckIRQuantDataType(const vector<TensorPtr>& weights, ge::DataType weightDataType)
{
    // Only INT8 weight quantization is accepted: the filter (weights[0]) must
    // be int8 and the optional bias (weights[1]) must be int32.
    if (weightDataType != DT_INT8) {
        return hiai::FAILED;
    }
    if (!weights.empty() && weights[0]->GetTensorDesc().GetDataType() != DT_INT8) {
        FMK_LOGE("filter type:%d must be int8", weights[0]->GetTensorDesc().GetDataType());
        return hiai::FAILED;
    }
    if (weights.size() > 1 && weights[1]->GetTensorDesc().GetDataType() != ge::DT_INT32) {
        FMK_LOGE("bias type:%d must be int32", weights[1]->GetTensorDesc().GetDataType());
        return hiai::FAILED;
    }
    return hiai::SUCCESS;
}

bool GenerateQuantizeConfig(ge::OpDesc& opDesc, QuantizeConfig& quantizeConfig)
{
    // Builds a QuantizeConfig from the IR quantize attributes on opDesc,
    // using the attribute-name set that matches the op type.
    // Use a pointer to the shared const vectors instead of copying a
    // vector<string> on every call, as the original did.
    const vector<string>* quantAttrs = &CONV_QUANT_ATTRS;
    if (opDesc.GetType() == hiai::op::FullyConnection::TYPE ||
        opDesc.GetType() == hiai::op::GemmD::TYPE) {
        quantAttrs = &FC_QUANT_ATTRS;
    } else if (opDesc.GetType() == hiai::op::MatMul::TYPE) {
        quantAttrs = &MATMUL_QUANT_ATTRS;
    }

    return GetQuantizeInfoFromIR(opDesc, *quantAttrs, quantizeConfig);
}

Status SaveOpQuantizeInfo(ge::OpDesc& opDesc, const QuantizeConfig& quantizeConfig, ge::TensorPtr filter)
{
    // Packs quantizeConfig into a QuantizeInfo and stores it on the op:
    // half-offset algorithm, scalar or per-channel weight scale mode, plus the
    // CONV_QUANTIZE_FLAG and U8S8 quant-type attributes. Finally stamps the
    // quantized weight data type onto the filter tensor descriptor.
    // Assumes filter is non-null (callers validate it) — TODO confirm if a
    // guard is wanted here as well.
    QuantizeInfo quantizeInfo;
    const float* inputScale = quantizeConfig.inputScale.data();
    HIAI_EXPECT_NOT_NULL_R(inputScale, hiai::PARAM_INVALID);
    const float* inputOffset = quantizeConfig.inputOffset.data();
    HIAI_EXPECT_NOT_NULL_R(inputOffset, hiai::PARAM_INVALID);
    quantizeInfo.set_scale_data_value(*inputScale);
    quantizeInfo.set_offset_data_value(*inputOffset);
    quantizeInfo.set_quantize_algo(HALF_OFFSET_ALGO);
    // More than one weight scale means per-channel (vector) quantization.
    if (quantizeConfig.weightScale.size() > 1) {
        quantizeInfo.set_scale_weight_mode(VECTOR_SCALE);
    } else {
        quantizeInfo.set_scale_weight_mode(SCALAR_SCALE);
    }
    quantizeInfo.set_scale_weight_value(
        quantizeConfig.weightScale.data(), quantizeConfig.weightScale.size() * sizeof(float));
    quantizeInfo.set_offset_weight_value(
        quantizeConfig.weightOffset.data(), quantizeConfig.weightOffset.size() * sizeof(float));

    if (QuantizeUtil::SetQuantizeInfo(opDesc, quantizeInfo) != ge::GRAPH_SUCCESS) {
        FMK_LOGE("Set quantize info failed, op name:%s", opDesc.GetName().c_str());
        return hiai::FAILED;
    }
    (void)ge::AttrUtils::SetBool(opDesc, "CONV_QUANTIZE_FLAG", true);
    // The legacy quantize IR only supports U8S8 quantization.
    HIAI_EXPECT_EXEC(QuantizeUtil::SetQuantType(opDesc, static_cast<int64_t>(UINT8_INT8_QUANTIZED)));
    filter->MutableTensorDesc().SetDataType(quantizeConfig.weightDataType);
    return hiai::SUCCESS;
}
} // namespace

Status QuantizeOptimizer::SetIRQuantizeInfos(ge::ComputeGraphPtr graph)
{
    // Converts IR-style quantize attributes on each quantizable node into the
    // internal QuantizeInfo representation stored on the op.
    HIAI_EXPECT_NOT_NULL(graph);

    auto nodeVisitor = [](ge::Node& node) {
        auto& opDesc = node.ROLE(NodeSpec).OpDesc();

        // Skip ops that are not quantizable types or lack valid IR quant
        // attributes/weights; short-circuit so the more expensive IR check is
        // not evaluated for unsupported types.
        if (!QuantizeUtil::IsSupportQuantOpType(opDesc.GetType()) || !IsIRSupportQuantize(node)) {
            return hiai::SUCCESS;
        }

        vector<TensorPtr> weightsPtr = OpDescUtils::MutableWeights(node);
        if (weightsPtr.empty()) {
            FMK_LOGE("No weight in op:%s.", opDesc.GetName().c_str());
            return hiai::FAILED;
        }
        uint32_t filterIndex = QuantizeUtil::GetFilterIndex(&node);
        // Guard the index before use: the original indexed weightsPtr
        // unchecked, risking an out-of-range access.
        if (filterIndex >= weightsPtr.size()) {
            FMK_LOGE("Filter index %u out of range for op:%s.", filterIndex, opDesc.GetName().c_str());
            return hiai::FAILED;
        }
        TensorPtr filter = weightsPtr[filterIndex];
        HIAI_EXPECT_NOT_NULL(filter);

        QuantizeConfig quantizeConfig;
        if (!GenerateQuantizeConfig(opDesc, quantizeConfig)) {
            FMK_LOGE("op:%s IR config is invalid.", opDesc.GetName().c_str());
            return hiai::FAILED;
        }
        // Per the latest IR definition for INT8 quantization, the filter must
        // be int8 and the bias int32.
        if (CheckIRQuantDataType(weightsPtr, quantizeConfig.weightDataType) != hiai::SUCCESS) {
            FMK_LOGE("Check weight data type fail, op name:%s", opDesc.GetName().c_str());
            return hiai::FAILED;
        }
        if (QuantizeUtil::CheckOpQuantizeInfo(node, quantizeConfig) != hiai::SUCCESS) {
            FMK_LOGE("Check quantize info failed, op name:%s", node.ROLE(NodeSpec).Name().c_str());
            return hiai::FAILED;
        }
        if (SaveOpQuantizeInfo(opDesc, quantizeConfig, filter) != hiai::SUCCESS) {
            // Fixed copy-paste log text: this branch is a save failure.
            FMK_LOGE("Save quantize info failed, op name:%s", opDesc.GetName().c_str());
            return hiai::FAILED;
        }
        return hiai::SUCCESS;
    };

    return graph->ROLE(GraphListWalker).WalkAllNodes(std::move(nodeVisitor));
}
} // namespace hiai
